package net.apexes.wsonrpc.client.support.websocket;

import javax.net.SocketFactory;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.URI;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * WebSocket client that runs a dedicated reader thread and delegates outgoing frames to a {@link WebSocketWriter}.
 *
 * <p>The reader thread handles socket creation, the HTTP opening handshake and frame reception.
 * Lifecycle events (open, error, close) are reported to the {@link WebSocketEventHandler} configured via
 * {@link #setEventHandler(WebSocketEventHandler)}. State transitions are guarded by this instance's monitor.
 */
public class WebSocketClient {

    private static final String THREAD_BASE_NAME = "WebSocket";
    private static final AtomicInteger CLIENT_COUNT = new AtomicInteger(0);

    /** Upper bound, in bytes, for one line of the server's HTTP handshake response. */
    private static final int MAX_HANDSHAKE_LINE_LENGTH = 1000;

    /** Lifecycle states of a client. {@link #connect(int, int)} is only accepted in {@link #NONE}. */
    public enum State {
        NONE, CONNECTING, CONNECTED, DISCONNECTING, DISCONNECTED
    }

    private static ThreadFactory threadFactory = Executors.defaultThreadFactory();
    private static ThreadInitializer intializer = Thread::setName;

    /**
     * Replaces the factory used to create the reader thread of clients constructed afterwards, and the hook used
     * to name threads.
     *
     * @param threadFactory factory for new threads
     * @param intializer    callback that assigns a name to a created thread
     */
    public static void setThreadFactory(ThreadFactory threadFactory, ThreadInitializer intializer) {
        WebSocketClient.threadFactory = threadFactory;
        WebSocketClient.intializer = intializer;
    }

    static ThreadFactory getThreadFactory() {
        return threadFactory;
    }

    static ThreadInitializer getIntializer() {
        return intializer;
    }

    static final byte OPCODE_NONE = 0x0;
    static final byte OPCODE_TEXT = 0x1;
    static final byte OPCODE_BINARY = 0x2;
    static final byte OPCODE_CLOSE = 0x8;
    static final byte OPCODE_PING = 0x9;
    static final byte OPCODE_PONG = 0xA;

    private volatile State state = State.NONE;
    private volatile Socket socket = null;

    private WebSocketEventHandler eventHandler = null;

    private final URI url;
    private int soTimeout;
    private int connectTimeout;

    private final WebSocketReceiver receiver;
    private final WebSocketWriter writer;
    private final WebSocketHandshake handshake;
    private final int clientId = CLIENT_COUNT.incrementAndGet();

    private final Thread innerThread;

    private SocketFactory socketFactory;
    private HostnameVerifier hostnameVerifier;

    /**
     * Create a websocket to connect to a given server
     *
     * @param url The URL of a websocket server
     */
    public WebSocketClient(URI url) {
        this(url, null);
    }

    /**
     * Create a websocket to connect to a given server. Include protocol in websocket handshake
     *
     * @param url      The URL of a websocket server
     * @param protocol The protocol to include in the handshake. If null, it will be omitted
     */
    public WebSocketClient(URI url, String protocol) {
        this(url, protocol, null);
    }

    /**
     * Create a websocket to connect to a given server. Include the given protocol in the handshake, as well as any
     * extra HTTP headers specified. Useful if you would like to include a User-Agent or other header
     *
     * @param url          The URL of a websocket server
     * @param protocol     The protocol to include in the handshake. If null, it will be omitted
     * @param extraHeaders Any extra HTTP headers to be included with the initial request. Pass null if no extra
     *                     headers are requested
     */
    public WebSocketClient(URI url, String protocol, Map<String, String> extraHeaders) {
        innerThread = getThreadFactory().newThread(new Runnable() {
            @Override
            public void run() {
                runReader();
            }
        });
        this.url = url;
        handshake = new WebSocketHandshake(url, protocol, extraHeaders);
        receiver = new WebSocketReceiver(this);
        writer = new WebSocketWriter(this, THREAD_BASE_NAME, clientId);
    }

    /**
     * Sets the factory used to create sockets for {@code wss} URLs. When unset, {@link SSLSocketFactory#getDefault()}
     * is used. Plain {@code ws} URLs always use a plain {@link Socket}.
     *
     * @param socketFactory factory for secure sockets, or null for the default
     */
    public void setSocketFactory(SocketFactory socketFactory) {
        this.socketFactory = socketFactory;
    }

    /**
     * Sets a verifier consulted after a {@code wss} connection is established. If it rejects the host, the
     * connection fails with a {@link WebSocketException}.
     *
     * @param hostnameVerifier verifier to apply, or null to skip verification
     */
    public void setHostnameVerifier(HostnameVerifier hostnameVerifier) {
        this.hostnameVerifier = hostnameVerifier;
    }

    /**
     * Must be called before connect(). Set the support for all websocket-related events.
     *
     * @param eventHandler The support to be triggered with relevant events
     */
    public void setEventHandler(WebSocketEventHandler eventHandler) {
        this.eventHandler = eventHandler;
    }

    public WebSocketEventHandler getEventHandler() {
        return this.eventHandler;
    }

    public State getState() {
        return state;
    }

    /**
     * Start up the socket. This is non-blocking, it will fire up the threads used by the library and then trigger the
     * onOpen support once the connection is established.
     *
     * <p>May only be called once; a second call reports an error and closes the client.
     *
     * @param connectTimeout timeout passed to {@link Socket#connect(java.net.SocketAddress, int)}, in milliseconds
     * @param soTimeout      read timeout passed to {@link Socket#setSoTimeout(int)}, in milliseconds
     */
    public synchronized void connect(int connectTimeout, int soTimeout) {
        if (state != State.NONE) {
            // Do not touch the timeouts here: the reader thread of the first connect() may be reading them.
            eventHandler.onError(new WebSocketException("connect() already called"));
            close();
            return;
        }
        this.connectTimeout = connectTimeout;
        this.soTimeout = soTimeout;
        getIntializer().setName(innerThread, THREAD_BASE_NAME + "Reader-" + clientId);
        state = State.CONNECTING;
        innerThread.start();
    }

    /**
     * Send a TEXT message over the socket
     *
     * @param data The text payload to be sent
     */
    public void send(String data) {
        send(OPCODE_TEXT, data.getBytes(StandardCharsets.UTF_8));
    }

    /**
     * Send a BINARY message over the socket
     *
     * @param data The binary payload to be sent
     */
    public void send(byte[] data) {
        send(OPCODE_BINARY, data);
    }

    /**
     * Send a PING frame.
     *
     * @param bytes ping payload; null is sent as an empty payload
     */
    public void ping(byte[] bytes) {
        if (bytes != null) {
            send(OPCODE_PING, bytes);
        } else {
            send(OPCODE_PING, new byte[] {});
        }
    }

    void pong(byte[] data) {
        send(OPCODE_PONG, data);
    }

    /**
     * Queues a single final frame. Errors are reported to the event handler rather than thrown. A write failure
     * also closes the client.
     */
    private synchronized void send(byte opcode, byte[] data) {
        if (state != State.CONNECTED) {
            // We might have been disconnected on another thread, just report an error
            eventHandler.onError(new WebSocketException("error while sending data: not connected"));
        } else {
            try {
                writer.send(opcode, true, data);
            } catch (IOException e) {
                eventHandler.onError(new WebSocketException("Failed to send frame", e));
                close();
            }
        }
    }

    void handleReceiverError(WebSocketException e) {
        eventHandler.onError(e);
        if (state == State.CONNECTED) {
            close();
        }
        closeSocket();
    }

    protected void onPong(byte[] payload) {
        // NOTE: as a client, we don't expect PONGs. No-op
    }

    /**
     * Close down the socket. Will trigger the onClose support if the socket has not been previously closed.
     */
    public synchronized void close() {
        switch (state) {
            case NONE:
                state = State.DISCONNECTED;
                return;
            case CONNECTING:
                // don't wait for an established connection, just close the tcp socket
                closeSocket();
                return;
            case CONNECTED:
                // This method also shuts down the writer
                // the socket will be closed once the ack for the close was received
                sendCloseHandshake(true);
                return;
            case DISCONNECTING:
                return; // no-op;
            case DISCONNECTED:
                return;  // No-op
        }
    }

    void onCloseOpReceived() {
        closeSocket();
    }

    /**
     * Stops receiver and writer, closes the TCP socket and fires onClose exactly once.
     */
    private synchronized void closeSocket() {
        if (state == State.DISCONNECTED) {
            return;
        }
        receiver.stopit();
        sendCloseHandshake(false);
        if (socket != null) {
            try {
                socket.close();
            } catch (IOException e) {
                // Report instead of throwing so the client still reaches DISCONNECTED and onClose still fires.
                eventHandler.onError(new WebSocketException("Failed to close socket", e));
            }
        }
        state = State.DISCONNECTED;

        eventHandler.onClose();
    }

    private void sendCloseHandshake(boolean clientClosed) {
        try {
            state = State.DISCONNECTING;
            // Set the stop flag then queue up a message. This ensures that the writer thread
            // will wake up, and since we set the stop flag, it will exit its run loop.
            writer.stopIt();
            writer.send(OPCODE_CLOSE, true, new byte[0]);
        } catch (IOException e) {
            if (clientClosed) {
                eventHandler.onError(new WebSocketException("Failed to send close frame", e));
            }
        } catch (WebSocketException e) {
            if (clientClosed) {
                throw e;
            }
        }
    }

    /**
     * Opens a TCP (ws) or TLS (wss) connection to {@link #url}.
     *
     * <p>Defaults to port 80 or 443. On any failure the partially created socket is closed before the exception
     * propagates.
     *
     * @throws WebSocketException for an unsupported scheme, an unknown host, an I/O failure or a rejected hostname
     */
    private Socket createSocket() {
        String scheme = url.getScheme();
        String host = url.getHost();
        int port = url.getPort();

        boolean secure;
        if ("ws".equals(scheme)) {
            secure = false;
        } else if ("wss".equals(scheme)) {
            secure = true;
        } else {
            throw new WebSocketException("unsupported protocol: " + scheme);
        }
        if (port == -1) {
            port = secure ? 443 : 80;
        }

        Socket created = null;
        boolean success = false;
        try {
            if (secure) {
                SocketFactory factory = socketFactory != null ? socketFactory : SSLSocketFactory.getDefault();
                created = factory.createSocket();
            } else {
                created = new Socket();
            }
            created.setSoTimeout(soTimeout);
            created.connect(new InetSocketAddress(host, port), connectTimeout);
            if (secure && hostnameVerifier != null) {
                verifyHostname(created, host);
            }
            success = true;
            return created;
        } catch (UnknownHostException uhe) {
            throw new WebSocketException("unknown host: " + host, uhe);
        } catch (IOException ioe) {
            String prefix = secure ? "error while creating secure socket to " : "error while creating socket to ";
            throw new WebSocketException(prefix + url, ioe);
        } finally {
            if (!success && created != null) {
                // Do not leak the file descriptor of a socket that failed to connect or verify.
                closeQuietly(created);
            }
        }
    }

    /**
     * Completes the TLS handshake and applies {@link #hostnameVerifier}.
     *
     * <p>The handshake is started explicitly because {@link SSLSocket#getSession()} returns an invalid session
     * instead of throwing when the handshake fails.
     *
     * @throws IOException        if the TLS handshake fails
     * @throws WebSocketException if the socket is not an {@link SSLSocket} or the verifier rejects the host
     */
    private void verifyHostname(Socket connected, String host) throws IOException {
        if (!(connected instanceof SSLSocket)) {
            throw new WebSocketException("hostname verification requires an SSLSocket, but the socket factory produced "
                    + connected.getClass().getName());
        }
        SSLSocket sslSocket = (SSLSocket) connected;
        sslSocket.startHandshake();
        if (!hostnameVerifier.verify(host, sslSocket.getSession())) {
            throw new WebSocketException("hostname verification failed for host: " + host);
        }
    }

    /** Closes a socket on an error path where a secondary close failure would only mask the original problem. */
    private static void closeQuietly(Socket toClose) {
        try {
            toClose.close();
        } catch (IOException ignored) {
            // Best effort: the socket is being discarded and the caller is already handling a failure or shutdown.
        }
    }

    /**
     * Blocks until both threads exit. The actual close must be triggered separately. This is just a convenience
     * method to make sure everything shuts down, if desired.
     *
     * @throws InterruptedException if interrupted while waiting for a thread to finish
     */
    public void blockClose() throws InterruptedException {
        // If the thread is new, it will never run, since we closed the connection before we actually connected
        if (writer.getInnerThread().getState() != Thread.State.NEW) {
            writer.getInnerThread().join();
        }
        innerThread.join();
    }

    /**
     * Body of the reader thread: connect, perform the opening handshake, then hand control to the receiver.
     *
     * <p>Errors are reported to the event handler. The client is always closed when this method returns.
     */
    private void runReader() {
        try {
            Socket created = createSocket();
            synchronized (this) {
                this.socket = created;
                if (state == State.DISCONNECTED) {
                    // close() ran while the socket was being created and has already fired onClose;
                    // discard the new socket without further notifications.
                    closeQuietly(created);
                    this.socket = null;
                    return;
                }
            }

            DataInputStream input = new DataInputStream(created.getInputStream());
            OutputStream output = created.getOutputStream();

            output.write(handshake.getHandshake());

            ArrayList<String> handshakeLines = readHandshakeLines(input);
            if (handshakeLines.isEmpty()) {
                throw new WebSocketException("Server sent an empty handshake response");
            }
            handshake.verifyServerStatusLine(handshakeLines.get(0));
            handshake.verifyServerHandshakeHeaders(parseHeaders(handshakeLines.subList(1, handshakeLines.size())));

            synchronized (this) {
                if (state != State.CONNECTING) {
                    // close() ran during the handshake and has already torn the client down;
                    // do not resurrect it as CONNECTED.
                    return;
                }
                writer.setOutput(output);
                receiver.setInput(input);
                state = State.CONNECTED;
                writer.getInnerThread().start();
            }
            eventHandler.onOpen();
            receiver.run();
        } catch (WebSocketException wse) {
            eventHandler.onError(wse);
        } catch (IOException ioe) {
            eventHandler
                    .onError(new WebSocketException("error while connecting: " + ioe.getMessage(), ioe));
        } finally {
            close();
        }
    }

    /**
     * Reads CRLF-terminated lines of the server's handshake response up to (not including) the empty line.
     *
     * @return trimmed, non-empty lines in arrival order (status line first)
     * @throws WebSocketException if the stream ends early or a line exceeds {@link #MAX_HANDSHAKE_LINE_LENGTH}
     */
    private static ArrayList<String> readHandshakeLines(DataInputStream input) throws IOException {
        ArrayList<String> lines = new ArrayList<>();
        byte[] buffer = new byte[MAX_HANDSHAKE_LINE_LENGTH];
        int pos = 0;
        while (true) {
            int b = input.read();
            if (b == -1) {
                throw new WebSocketException("Connection closed before handshake was complete");
            }
            buffer[pos] = (byte) b;
            pos += 1;

            // pos >= 2 guards against reading buffer[-1] when a line starts with a bare LF.
            if (pos >= 2 && buffer[pos - 2] == 0x0D && buffer[pos - 1] == 0x0A) {
                // Decode only the bytes of this line, not the whole buffer.
                String line = new String(buffer, 0, pos, StandardCharsets.UTF_8).trim();
                if (line.isEmpty()) {
                    return lines;
                }
                lines.add(line);
                pos = 0;
            } else if (pos == buffer.length) {
                // This really shouldn't happen, handshake lines are short, but just to be safe...
                String line = new String(buffer, 0, pos, StandardCharsets.UTF_8);
                throw new WebSocketException("Unexpected long line in handshake: " + line);
            }
        }
    }

    /**
     * Parses {@code Name: value} header lines into a map keyed by the lower-cased name.
     *
     * <p>Whitespace around the colon is optional. A later duplicate header replaces an earlier one.
     *
     * @throws WebSocketException if a line has no header name before a colon
     */
    private static HashMap<String, String> parseHeaders(Iterable<String> lines) {
        HashMap<String, String> headers = new HashMap<>();
        for (String line : lines) {
            int colon = line.indexOf(':');
            if (colon <= 0) {
                throw new WebSocketException("Malformed header line in handshake: " + line);
            }
            String name = line.substring(0, colon).trim().toLowerCase(Locale.US);
            String value = line.substring(colon + 1).trim();
            headers.put(name, value);
        }
        return headers;
    }
}
